package xsuite

import (
	"context"
	"database/sql"
	"gitee.com/xfrm/middleware/xsql/xdb"

	"gitee.com/xfrm/middleware/xsql/manager"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/agiledragon/gomonkey"
	"github.com/stretchr/testify/suite"
	"github.com/zhulongcheng/testsql"
)

// LogicTestSqlSuite is a testify suite for testing logic code against a real
// SQL database. SetupSuite builds a testsql.TestSQL from DSN, TableSchemaPath
// and FixtureDirPath, and patches manager.GetDB so that it returns an xdb.DB
// wrapping that database's connection. TearDownSuite drops the test database,
// closes the connection and resets Patches.
type LogicTestSqlSuite struct {
	suite.Suite

	DB      *sql.DB           // DB is the connection to the test database (taken from TS.DB).
	TS      *testsql.TestSQL  // TS is the testsql helper used to generate / clean up test data.
	Patches *gomonkey.Patches // Patches holds the suite's monkey patches; TearDownSuite calls Reset on it.

	DSN             string // DSN is the data source name, see https://github.com/go-sql-driver/mysql#dsn-data-source-name; defaults to defaultTestSqlDSN when empty.
	TableSchemaPath string // TableSchemaPath is the schema file; defaults to "../testsql/schema.sql" (relative to the working directory of the test).
	FixtureDirPath  string // FixtureDirPath is the directory of fixture data; defaults to "../testsql/fixtures".
}

// SetupSuite fills in default DSN/paths for any empty fields, builds the
// testsql helper (exposing its connection as p.DB) and patches
// manager.GetDB so that, for any cluster/table, logic code under test gets
// an xdb.DB backed by the test database.
func (p *LogicTestSqlSuite) SetupSuite() {
	if p.DSN == "" {
		p.DSN = defaultTestSqlDSN
	}
	if p.TableSchemaPath == "" {
		p.TableSchemaPath = "../testsql/schema.sql"
	}
	if p.FixtureDirPath == "" {
		p.FixtureDirPath = "../testsql/fixtures"
	}
	ts := testsql.New(p.DSN, p.TableSchemaPath, p.FixtureDirPath)
	p.TS = ts
	p.DB = ts.DB

	// Keep the *Patches returned by ApplyFunc: it is the only handle that can
	// undo this patch. Replacing it with a fresh gomonkey.NewPatches() would
	// leave TearDownSuite resetting an empty set, so manager.GetDB would stay
	// patched after the suite finished.
	p.Patches = gomonkey.ApplyFunc(manager.GetDB, func(ctx context.Context, cluster, table string) (*xdb.DB, error) {
		// Named db (not xdb) to avoid shadowing the xdb package.
		db := new(xdb.DB)
		db.SetSQLDB(ts.DB)
		return db, nil
	})
}

// TearDownSuite drops the test database, closes the connection and resets
// the suite's monkey patches, in that order. Closing and resetting are
// deferred so they still run if DropTestDB panics; otherwise a failed drop
// would leak the connection and leave manager.GetDB patched.
func (p *LogicTestSqlSuite) TearDownSuite() {
	// Deferred calls run LIFO: Close first, then Reset.
	if p.Patches != nil {
		defer p.Patches.Reset()
	}
	if p.DB != nil {
		defer func() {
			// Logged rather than asserted: a driver may report an error when
			// closing an idle connection the server already dropped, which is
			// not a test failure.
			if err := p.DB.Close(); err != nil {
				p.T().Logf("close test db: %v", err)
			}
		}()
	}
	if p.TS != nil {
		p.TS.DropTestDB()
	}
}

// LogicMockDBSuite is a testify suite for testing logic code against a
// go-sqlmock database. SetupSuite creates the mock with sqlmock.New and
// patches manager.GetDB so that it returns an xdb.DB wrapping the mock
// connection; TearDownSuite closes the connection and resets Patches.
type LogicMockDBSuite struct {
	suite.Suite

	Mock    sqlmock.Sqlmock   // Mock is used to register expected queries and their results.
	DB      *sql.DB           // DB is the mock database connection.
	Patches *gomonkey.Patches // Patches holds the suite's monkey patches; TearDownSuite calls Reset on it.
}

// SetupSuite creates a go-sqlmock database, stores its connection and
// expectation handle on the suite, and patches manager.GetDB so that, for
// any cluster/table, logic code under test gets an xdb.DB backed by the mock.
func (p *LogicMockDBSuite) SetupSuite() {
	db, mock, err := sqlmock.New()
	// Require aborts setup on failure; a plain NoError would carry on and
	// wire a broken mock into manager.GetDB.
	p.Require().NoError(err)

	p.Mock = mock
	p.DB = db

	// Keep the *Patches returned by ApplyFunc so TearDownSuite's Reset
	// actually restores manager.GetDB.
	p.Patches = gomonkey.ApplyFunc(manager.GetDB, func(ctx context.Context, cluster, table string) (*xdb.DB, error) {
		// Named conn (not xdb) to avoid shadowing the xdb package.
		conn := new(xdb.DB)
		conn.SetSQLDB(db)
		return conn, nil
	})
}

// TearDownSuite closes the mock connection and then resets the suite's
// monkey patches.
func (p *LogicMockDBSuite) TearDownSuite() {
	mockDB, patches := p.DB, p.Patches

	// The Close error is deliberately discarded. go-sqlmock can report Close
	// as an unexpected call unless the test registered Mock.ExpectClose(), so
	// asserting on it here would fail suites that never set that expectation.
	_ = mockDB.Close()
	patches.Reset()
}
